#!/usr/bin/python
# -*- encoding: utf-8 -*-

import numpy as np
from model import BiSeNet

import torch

import os
import os.path as osp

from PIL import Image
import torchvision.transforms as transforms
import cv2
from pathlib import Path
import configargparse
import tqdm

# import ttach as tta

def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
                     img_size=(512, 512)):
    """Render a parsing label map as a colour-coded image and optionally save it.

    Labels are cast to ``uint8`` and coloured in groups (the triplets are
    handed unchanged to ``cv2.imwrite``, which treats channels as BGR):

    * ``0``                   -> ``(255, 255, 255)``
    * ``1``-``13`` and ``>= 17`` -> ``(255, 0, 0)``
    * ``14``-``15``           -> ``(0, 255, 0)``
    * ``16``                  -> ``(0, 0, 255)``

    Args:
        im: Source image. Not used for rendering; kept so existing callers
            do not break.
        parsing_anno: Array of integer class labels, expected to be 2-D
            ``(H, W)``.
        stride: Scale factor applied to the label map (nearest-neighbour)
            before colouring.
        save_im: When true, write the rendered image to ``save_path``.
        save_path: Destination file passed to ``cv2.imwrite``.
        img_size: ``(width, height)`` of the rendered output, as expected by
            ``cv2.resize``.

    Returns:
        The rendered ``uint8`` colour image after resizing to ``img_size``.

    Raises:
        OSError: If ``save_im`` is true and ``cv2.imwrite`` reports failure
            (for example because the target directory does not exist).
    """
    labels = np.asarray(parsing_anno).astype(np.uint8)
    labels = cv2.resize(
        labels, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)

    # One lookup table replaces the per-label np.where passes. It also avoids
    # doing arithmetic on a uint8 scalar (max label + 1), which can wrap
    # around under NumPy 2 promotion rules.
    palette = np.empty((256, 3), dtype=np.uint8)
    palette[:] = (255, 0, 0)          # labels 1-13 and every label >= 17
    palette[0] = (255, 255, 255)      # label 0 keeps the white canvas
    palette[14:16] = (0, 255, 0)      # labels 14 and 15
    palette[16] = (0, 0, 255)         # label 16

    colored = palette[labels]
    vis_im = cv2.resize(colored, img_size, interpolation=cv2.INTER_NEAREST)

    if save_im:
        # cv2.imwrite signals failure through its return value; surface it
        # instead of silently producing no file.
        if not cv2.imwrite(save_path, vis_im):
            raise OSError(f'failed to write parsing map to {save_path!r}')

    return vis_im


@torch.no_grad()
def evaluate_one(dspth='./data', cp='model_final_diss.pth'):
    """Run the parsing network on one image and return a binary label mask.

    The image is resized to 512x512, converted to RGB, normalised and passed
    through ``BiSeNet``. Pixels whose predicted class is in ``1..13`` are set
    to 1 in the result; all other pixels are 0.

    Args:
        dspth: Path to a single ``.jpg`` or ``.png`` image (despite the
            directory-like default value).
        cp: Path to the checkpoint loaded with ``torch.load`` and applied
            through ``load_state_dict``.

    Returns:
        Integer numpy array with the same shape as the per-pixel argmax of
        the network output (presumably ``(512, 512)``; this depends on what
        ``BiSeNet`` returns - TODO confirm).

    Raises:
        AssertionError: If ``dspth`` does not end in ``.jpg`` or ``.png``.
    """
    print('[INFO] loading face parsing model...')
    n_classes = 19
    # Use the GPU when present, but fall back to CPU instead of crashing on
    # hosts without CUDA; map_location keeps the checkpoint load consistent.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = BiSeNet(n_classes=n_classes)
    net.to(device)
    net.load_state_dict(torch.load(cp, map_location=device))
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    image_path = dspth

    assert image_path.endswith(('.jpg', '.png')), \
        f'expected a .jpg or .png image path, got {image_path!r}'

    # Close the file handle as soon as the pixels have been read; the
    # resize-then-convert order is kept as in the original pipeline.
    with Image.open(image_path) as img:
        image = img.resize((512, 512), Image.BILINEAR).convert("RGB")

    inputs = torch.unsqueeze(to_tensor(image), 0)  # [1, 3, 512, 512]
    outputs = net(inputs.to(device))
    # Single forward pass (no augmentation): mean(0) only collapses the
    # batch dimension of size 1 before the per-pixel argmax.
    parsing = outputs.mean(0).cpu().numpy().argmax(0)

    vis_parsing_ann = np.zeros_like(parsing)
    vis_parsing_ann[(parsing >= 1) & (parsing <= 13)] = 1

    return vis_parsing_ann


def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
    """Parse every ``.jpg``/``.png`` image in a directory and save label maps.

    Each image is resized to 512x512, converted to RGB, normalised and passed
    through ``BiSeNet``. The per-pixel argmax is rendered by
    ``vis_parsing_maps`` at the image's original size and written to
    ``respth`` as a ``.png``.

    Output naming: a purely numeric stem is normalised through ``int`` (so
    ``0007.jpg`` becomes ``7.png``); any other stem is kept unchanged rather
    than aborting the whole batch.

    Args:
        respth: Output directory; created (with parents) if missing.
        dspth: Directory whose entries are scanned for images.
        cp: Path to the checkpoint loaded with ``torch.load`` and applied
            through ``load_state_dict``.
    """
    Path(respth).mkdir(parents=True, exist_ok=True)

    print('[INFO] loading model...')
    n_classes = 19
    # Use the GPU when present, but fall back to CPU instead of crashing on
    # hosts without CUDA; map_location keeps the checkpoint load consistent.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = BiSeNet(n_classes=n_classes)
    net.to(device)
    net.load_state_dict(torch.load(cp, map_location=device))
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    image_names = os.listdir(dspth)

    with torch.no_grad():
        for image_name in tqdm.tqdm(image_names):
            if not image_name.endswith(('.jpg', '.png')):
                continue

            # Close the file handle as soon as the pixels have been read; the
            # resize-then-convert order is kept as in the original pipeline.
            with Image.open(osp.join(dspth, image_name)) as img:
                ori_size = img.size
                image = img.resize((512, 512), Image.BILINEAR).convert("RGB")

            inputs = torch.unsqueeze(to_tensor(image), 0)  # [1, 3, 512, 512]
            outputs = net(inputs.to(device))
            # Single forward pass (no augmentation): mean(0) only collapses
            # the batch dimension of size 1 before the per-pixel argmax.
            parsing = outputs.mean(0).cpu().numpy().argmax(0)

            # Both accepted extensions are four characters long.
            stem = image_name[:-4]
            try:
                stem = str(int(stem))
            except ValueError:
                # Non-numeric file name: keep the stem as it is.
                pass

            vis_parsing_maps(image, parsing, stride=1, save_im=True,
                             save_path=osp.join(respth, stem + '.png'), img_size=ori_size)


if __name__ == "__main__":
    # Command-line entry point: collect the three paths and run the batch
    # evaluation over the image directory.
    cli = configargparse.ArgumentParser()
    cli.add_argument(
        '--respath',
        type=str,
        default='./result/',
        help='result path for label',
    )
    cli.add_argument(
        '--imgpath',
        type=str,
        default='./imgs/',
        help='path for input images',
    )
    cli.add_argument(
        '--modelpath',
        type=str,
        default='data_utils/face_parsing/79999_iter.pth',
    )
    opts = cli.parse_args()

    evaluate(
        respth=opts.respath,
        dspth=opts.imgpath,
        cp=opts.modelpath,
    )
